{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/bert-standard/bert/optimization.py:87: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import bert\n",
    "from bert import optimization\n",
    "from bert import tokenization\n",
    "from bert import modeling\n",
    "import numpy as np\n",
    "import json\n",
    "import tensorflow as tf\n",
    "import itertools\n",
    "import collections\n",
    "import re\n",
    "import random\n",
    "import sentencepiece as spm\n",
    "from unidecode import unidecode\n",
    "from sklearn.utils import shuffle\n",
    "from prepro_utils import preprocess_text, encode_ids, encode_pieces\n",
    "from malaya.text.function import transformer_textcleaning as cleaning\n",
    "from tensorflow.python.estimator.run_config import RunConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Topic names; XY draws its candidate topics from this set.\n",
    "with open('/home/husein/alxlnet/topics.json') as fopen:\n",
    "    topics = set(json.load(fopen).keys())\n",
    "list_topics = list(topics)\n",
    "\n",
    "# SentencePiece model backing the Tokenizer wrapper defined below.\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('sp10m.cased.bert.model')\n",
    "\n",
    "# Vocab file: tab-separated fields per line; the last split element is dropped.\n",
    "with open('sp10m.cased.bert.vocab') as fopen:\n",
    "    v = fopen.read().split('\\n')[:-1]\n",
    "v = {fields[0]: fields[1] for fields in (row.split('\\t') for row in v)}\n",
    "\n",
    "\n",
    "class Tokenizer:\n",
    "    \"\"\"Thin adapter exposing a BERT-tokenizer-like API on top of `sp_model`.\"\"\"\n",
    "\n",
    "    def __init__(self, v):\n",
    "        # The vocabulary mapping is only stored; the methods below go through sp_model.\n",
    "        self.vocab = v\n",
    "\n",
    "    def tokenize(self, string):\n",
    "        \"\"\"Split `string` into pieces via `encode_pieces` with sample = False.\"\"\"\n",
    "        pieces = encode_pieces(\n",
    "            sp_model, string, return_unicode = False, sample = False\n",
    "        )\n",
    "        return pieces\n",
    "\n",
    "    def convert_tokens_to_ids(self, tokens):\n",
    "        \"\"\"Map each piece to its id with `sp_model.PieceToId`.\"\"\"\n",
    "        return list(map(sp_model.PieceToId, tokens))\n",
    "\n",
    "    def convert_ids_to_tokens(self, ids):\n",
    "        \"\"\"Map each id back to its piece with `sp_model.IdToPiece`.\"\"\"\n",
    "        return list(map(sp_model.IdToPiece, ids))\n",
    "\n",
    "\n",
    "tokenizer = Tokenizer(v = v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def F(text):\n",
    "    \"\"\"Tokenize `text` and return `(ids, mask)` for '[CLS]' + pieces + '[SEP]'.\n",
    "\n",
    "    The mask contains a 1 for every produced id.\n",
    "    \"\"\"\n",
    "    wrapped = ['[CLS]'] + tokenizer.tokenize(text) + ['[SEP]']\n",
    "    ids = tokenizer.convert_tokens_to_ids(wrapped)\n",
    "    return ids, [1] * len(ids)\n",
    "\n",
    "\n",
    "def XY(data):\n",
    "    \"\"\"Turn one `(text, keyphrases)` record into `(X, Y, label)`.\n",
    "\n",
    "    If any keyphrase is a known topic, then with probability ~0.8 one of the\n",
    "    record's keyphrases is used as a positive pair (label 1). Otherwise a\n",
    "    topic that is not among the record's keyphrases is drawn as a negative\n",
    "    pair (label 0). X and Y are the `F` encodings of the cleaned text and of\n",
    "    the chosen phrase.\n",
    "    \"\"\"\n",
    "    keyphrases = data[1]\n",
    "    own = set(keyphrases)\n",
    "\n",
    "    # random.random() is only consumed when the record overlaps the topic set.\n",
    "    use_positive = bool(own & topics) and random.random() > 0.2\n",
    "    if use_positive:\n",
    "        t, label = random.choice(keyphrases), 1\n",
    "    else:\n",
    "        t, label = random.choice(list(topics - own)), 0\n",
    "\n",
    "    return F(cleaning(data[0])), F(t), label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyphrase test-set records, parsed from JSON.\n",
    "with open('/home/husein/alxlnet/testset-keyphrase.json') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.sequence import pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_initializer(initializer_range = 0.02):\n",
    "    \"\"\"Return a truncated-normal initializer with stddev `initializer_range`.\"\"\"\n",
    "    initializer = tf.truncated_normal_initializer(stddev = initializer_range)\n",
    "    return initializer\n",
    "\n",
    "\n",
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint):\n",
    "    \"\"\"Match checkpoint variable names to graph variables named 'bert/' + name.\n",
    "\n",
    "    Returns `(assignment_map, initialized_variable_names)`: the first maps a\n",
    "    checkpoint name to the matching graph variable, the second flags every\n",
    "    matched name both with and without the ':0' suffix.\n",
    "    \"\"\"\n",
    "    # Index graph variables by name with any trailing ':<digits>' removed.\n",
    "    name_to_variable = collections.OrderedDict()\n",
    "    for variable in tvars:\n",
    "        matched = re.match('^(.*):\\\\d+$', variable.name)\n",
    "        key = variable.name if matched is None else matched.group(1)\n",
    "        name_to_variable[key] = variable\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    initialized_variable_names = {}\n",
    "    for entry in tf.train.list_variables(init_checkpoint):\n",
    "        ckpt_name, _ = entry[0], entry[1]\n",
    "        scoped_name = 'bert/' + ckpt_name\n",
    "        if scoped_name in name_to_variable:\n",
    "            assignment_map[ckpt_name] = name_to_variable[scoped_name]\n",
    "            initialized_variable_names[ckpt_name] = 1\n",
    "            initialized_variable_names[ckpt_name + ':0'] = 1\n",
    "\n",
    "    return (assignment_map, initialized_variable_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter constants; warmup steps are a fixed proportion of the total steps.\n",
    "learning_rate = 2e-5\n",
    "batch_size = 60\n",
    "num_train_steps = 1000000\n",
    "warmup_proportion = 0.1\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/bert-standard/bert/modeling.py:93: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# BERT configuration is read from the tiny-bert config file.\n",
    "bert_config = modeling.BertConfig.from_json_file('tiny-bert-v1/config.json')\n",
    "\n",
    "class Model:\n",
    "    \"\"\"Siamese BERT classifier over a pair of token-id sequences.\n",
    "\n",
    "    Both inputs run through the same BERT weights (the second tower reuses\n",
    "    the 'bert' variable scope). The pooled outputs u and v are combined as\n",
    "    [u, v, |u - v|] and projected to `dimension_output` logits.\n",
    "\n",
    "    Args:\n",
    "        dimension_output: number of output classes.\n",
    "        is_training: forwarded to `modeling.BertModel`. Defaults to False\n",
    "            because this graph is built to be restored, frozen and\n",
    "            quantized for inference; building it with True keeps BERT's\n",
    "            dropout in the exported graph, which makes predictions noisy\n",
    "            and non-repeatable.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output = 2,\n",
    "        is_training = False,\n",
    "    ):\n",
    "        # Placeholders are unnamed and created in a fixed order: later cells\n",
    "        # address them as 'Placeholder', 'Placeholder_1', ... by position.\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.float32, [None, None])\n",
    "\n",
    "        self.X_b = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks_b = tf.placeholder(tf.float32, [None, None])\n",
    "\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "\n",
    "        # First tower: creates the BERT variables under the 'bert' scope.\n",
    "        with tf.compat.v1.variable_scope('bert', reuse = False):\n",
    "            model = modeling.BertModel(\n",
    "                config = bert_config,\n",
    "                is_training = is_training,\n",
    "                input_ids = self.X,\n",
    "                input_mask = self.input_masks,\n",
    "                use_one_hot_embeddings = False,\n",
    "            )\n",
    "\n",
    "            summary = model.get_pooled_output()\n",
    "            # Named so the pooled vector can be fetched as 'bert/summary'.\n",
    "            summary = tf.identity(summary, name = 'summary')\n",
    "            self.summary = summary\n",
    "\n",
    "        # Second tower: same variables (reuse = True), fed with the B inputs.\n",
    "        with tf.compat.v1.variable_scope('bert', reuse = True):\n",
    "            model = modeling.BertModel(\n",
    "                config = bert_config,\n",
    "                is_training = is_training,\n",
    "                input_ids = self.X_b,\n",
    "                input_mask = self.input_masks_b,\n",
    "                use_one_hot_embeddings = False,\n",
    "            )\n",
    "\n",
    "            summary_b = model.get_pooled_output()\n",
    "\n",
    "        vectors_concat = [summary, summary_b, tf.abs(summary - summary_b)]\n",
    "        vectors_concat = tf.concat(vectors_concat, axis = 1)\n",
    "\n",
    "        self.logits = tf.layers.dense(vectors_concat, dimension_output)\n",
    "        self.logits = tf.identity(self.logits, name = 'logits')\n",
    "\n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/bert-standard/bert/modeling.py:171: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/bert-standard/bert/modeling.py:409: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/bert-standard/bert/modeling.py:490: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.\n",
      "\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/bert-standard/bert/modeling.py:358: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
      "WARNING:tensorflow:From /home/husein/bert-standard/bert/modeling.py:671: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 2\n",
    "\n",
    "# Build the model inside a freshly reset default graph and an interactive session.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(dimension_output)\n",
    "\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from tiny-bert-keyphrase/model.ckpt-620000\n"
     ]
    }
   ],
   "source": [
    "# Restore the trainable variables from the fine-tuned checkpoint.\n",
    "checkpoint = 'tiny-bert-keyphrase/model.ckpt-620000'\n",
    "saver = tf.train.Saver(var_list = tf.trainable_variables())\n",
    "saver.restore(sess, save_path = checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'output-tiny-bert-keyphrase/model.ckpt'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Write the trainable variables to the export directory.\n",
    "saver = tf.train.Saver(var_list = tf.trainable_variables())\n",
    "saver.save(sess, save_path = 'output-tiny-bert-keyphrase/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Placeholder',\n",
       " 'Placeholder_1',\n",
       " 'Placeholder_2',\n",
       " 'Placeholder_3',\n",
       " 'Placeholder_4',\n",
       " 'bert/bert/embeddings/word_embeddings',\n",
       " 'bert/bert/embeddings/token_type_embeddings',\n",
       " 'bert/bert/embeddings/position_embeddings',\n",
       " 'bert/bert/embeddings/LayerNorm/gamma',\n",
       " 'bert/bert/encoder/layer_0/attention/self/query/kernel',\n",
       " 'bert/bert/encoder/layer_0/attention/self/query/bias',\n",
       " 'bert/bert/encoder/layer_0/attention/self/key/kernel',\n",
       " 'bert/bert/encoder/layer_0/attention/self/key/bias',\n",
       " 'bert/bert/encoder/layer_0/attention/self/value/kernel',\n",
       " 'bert/bert/encoder/layer_0/attention/self/value/bias',\n",
       " 'bert/bert/encoder/layer_0/attention/self/Softmax',\n",
       " 'bert/bert/encoder/layer_0/attention/output/dense/kernel',\n",
       " 'bert/bert/encoder/layer_0/attention/output/dense/bias',\n",
       " 'bert/bert/encoder/layer_0/attention/output/LayerNorm/gamma',\n",
       " 'bert/bert/encoder/layer_0/intermediate/dense/kernel',\n",
       " 'bert/bert/encoder/layer_0/intermediate/dense/bias',\n",
       " 'bert/bert/encoder/layer_0/output/dense/kernel',\n",
       " 'bert/bert/encoder/layer_0/output/dense/bias',\n",
       " 'bert/bert/encoder/layer_0/output/LayerNorm/gamma',\n",
       " 'bert/bert/encoder/layer_1/attention/self/query/kernel',\n",
       " 'bert/bert/encoder/layer_1/attention/self/query/bias',\n",
       " 'bert/bert/encoder/layer_1/attention/self/key/kernel',\n",
       " 'bert/bert/encoder/layer_1/attention/self/key/bias',\n",
       " 'bert/bert/encoder/layer_1/attention/self/value/kernel',\n",
       " 'bert/bert/encoder/layer_1/attention/self/value/bias',\n",
       " 'bert/bert/encoder/layer_1/attention/self/Softmax',\n",
       " 'bert/bert/encoder/layer_1/attention/output/dense/kernel',\n",
       " 'bert/bert/encoder/layer_1/attention/output/dense/bias',\n",
       " 'bert/bert/encoder/layer_1/attention/output/LayerNorm/gamma',\n",
       " 'bert/bert/encoder/layer_1/intermediate/dense/kernel',\n",
       " 'bert/bert/encoder/layer_1/intermediate/dense/bias',\n",
       " 'bert/bert/encoder/layer_1/output/dense/kernel',\n",
       " 'bert/bert/encoder/layer_1/output/dense/bias',\n",
       " 'bert/bert/encoder/layer_1/output/LayerNorm/gamma',\n",
       " 'bert/bert/encoder/layer_2/attention/self/query/kernel',\n",
       " 'bert/bert/encoder/layer_2/attention/self/query/bias',\n",
       " 'bert/bert/encoder/layer_2/attention/self/key/kernel',\n",
       " 'bert/bert/encoder/layer_2/attention/self/key/bias',\n",
       " 'bert/bert/encoder/layer_2/attention/self/value/kernel',\n",
       " 'bert/bert/encoder/layer_2/attention/self/value/bias',\n",
       " 'bert/bert/encoder/layer_2/attention/self/Softmax',\n",
       " 'bert/bert/encoder/layer_2/attention/output/dense/kernel',\n",
       " 'bert/bert/encoder/layer_2/attention/output/dense/bias',\n",
       " 'bert/bert/encoder/layer_2/attention/output/LayerNorm/gamma',\n",
       " 'bert/bert/encoder/layer_2/intermediate/dense/kernel',\n",
       " 'bert/bert/encoder/layer_2/intermediate/dense/bias',\n",
       " 'bert/bert/encoder/layer_2/output/dense/kernel',\n",
       " 'bert/bert/encoder/layer_2/output/dense/bias',\n",
       " 'bert/bert/encoder/layer_2/output/LayerNorm/gamma',\n",
       " 'bert/bert/encoder/layer_3/attention/self/query/kernel',\n",
       " 'bert/bert/encoder/layer_3/attention/self/query/bias',\n",
       " 'bert/bert/encoder/layer_3/attention/self/key/kernel',\n",
       " 'bert/bert/encoder/layer_3/attention/self/key/bias',\n",
       " 'bert/bert/encoder/layer_3/attention/self/value/kernel',\n",
       " 'bert/bert/encoder/layer_3/attention/self/value/bias',\n",
       " 'bert/bert/encoder/layer_3/attention/self/Softmax',\n",
       " 'bert/bert/encoder/layer_3/attention/output/dense/kernel',\n",
       " 'bert/bert/encoder/layer_3/attention/output/dense/bias',\n",
       " 'bert/bert/encoder/layer_3/attention/output/LayerNorm/gamma',\n",
       " 'bert/bert/encoder/layer_3/intermediate/dense/kernel',\n",
       " 'bert/bert/encoder/layer_3/intermediate/dense/bias',\n",
       " 'bert/bert/encoder/layer_3/output/dense/kernel',\n",
       " 'bert/bert/encoder/layer_3/output/dense/bias',\n",
       " 'bert/bert/encoder/layer_3/output/LayerNorm/gamma',\n",
       " 'bert/bert/pooler/dense/kernel',\n",
       " 'bert/bert/pooler/dense/bias',\n",
       " 'bert/summary',\n",
       " 'bert_1/bert/encoder/layer_0/attention/self/Softmax',\n",
       " 'bert_1/bert/encoder/layer_1/attention/self/Softmax',\n",
       " 'bert_1/bert/encoder/layer_2/attention/self/Softmax',\n",
       " 'bert_1/bert/encoder/layer_3/attention/self/Softmax',\n",
       " 'dense/kernel',\n",
       " 'dense/bias',\n",
       " 'logits']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep variable ops plus nodes whose name contains one of the wanted markers,\n",
    "# then drop anything whose name contains one of the excluded markers.\n",
    "strings = ','.join(\n",
    "    node.name\n",
    "    for node in tf.get_default_graph().as_graph_def().node\n",
    "    if (\n",
    "        'Variable' in node.op\n",
    "        or any(\n",
    "            key in node.name\n",
    "            for key in ('Placeholder', 'logits', 'alphas', 'summary', 'self/Softmax')\n",
    "        )\n",
    "    )\n",
    "    and not any(\n",
    "        key in node.name\n",
    "        for key in ('Adam', 'beta', 'global_step', 'Identity', 'Assign')\n",
    "    )\n",
    ")\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "    \"\"\"Freeze the latest checkpoint in `model_dir` into `frozen_model.pb`.\n",
    "\n",
    "    Variables needed by the requested output nodes are converted to\n",
    "    constants and the resulting GraphDef is written next to the checkpoint.\n",
    "\n",
    "    Args:\n",
    "        model_dir: directory holding the checkpoint state and `.meta` file.\n",
    "        output_node_names: comma-separated names of the nodes to keep.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `model_dir` does not exist or has no checkpoint.\n",
    "    \"\"\"\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    # get_checkpoint_state returns None when no checkpoint state is found;\n",
    "    # report that clearly instead of failing on the attribute access below.\n",
    "    if checkpoint is None or not checkpoint.model_checkpoint_path:\n",
    "        raise AssertionError(\n",
    "            'No checkpoint found in directory: %s' % model_dir\n",
    "        )\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    # os.path keeps a bare checkpoint name relative; splitting and joining on\n",
    "    # '/' by hand would map it to '/frozen_model.pb' at the filesystem root.\n",
    "    output_graph = os.path.join(\n",
    "        os.path.dirname(input_checkpoint), 'frozen_model.pb'\n",
    "    )\n",
    "    with tf.Session(graph = tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices = True\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from output-tiny-bert-keyphrase/model.ckpt\n",
      "WARNING:tensorflow:From <ipython-input-14-9a7215a4e58a>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.convert_variables_to_constants`\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "INFO:tensorflow:Froze 73 variables.\n",
      "INFO:tensorflow:Converted 73 variables to const ops.\n",
      "1553 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "freeze_graph(model_dir = 'output-tiny-bert-keyphrase', output_node_names = strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph(frozen_graph_filename):\n",
    "    \"\"\"Parse a serialized GraphDef from disk and import it into a new tf.Graph.\"\"\"\n",
    "    graph_def = tf.GraphDef()\n",
    "    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:\n",
    "        graph_def.ParseFromString(f.read())\n",
    "\n",
    "    graph = tf.Graph()\n",
    "    with graph.as_default():\n",
    "        tf.import_graph_def(graph_def)\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = load_graph(frozen_graph_filename = 'output-tiny-bert-keyphrase/frozen_model.pb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transform specs handed to TransformGraph, in this order.\n",
    "transforms = [\n",
    "    'add_default_attributes',\n",
    "    'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
    "    'fold_batch_norms',\n",
    "    'fold_old_batch_norms',\n",
    "    'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
    "    'strip_unused_nodes',\n",
    "    'sort_by_execution_order',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.tools.graph_transforms import TransformGraph\n",
    "tf.set_random_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-20-3ba16c77f9fe>:4: FastGFile.__init__ (from tensorflow.python.platform.gfile) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.gfile.GFile.\n"
     ]
    }
   ],
   "source": [
    "pb = 'output-tiny-bert-keyphrase/frozen_model.pb'\n",
    "\n",
    "# tf.gfile.FastGFile is deprecated; tf.gfile.GFile is the suggested replacement.\n",
    "input_graph_def = tf.GraphDef()\n",
    "with tf.gfile.GFile(pb, 'rb') as f:\n",
    "    input_graph_def.ParseFromString(f.read())\n",
    "\n",
    "# Placeholder names follow creation order in Model; the label placeholder\n",
    "# ('Placeholder_4') is left out because neither listed output depends on it.\n",
    "inputs = ['Placeholder', 'Placeholder_1', 'Placeholder_2', 'Placeholder_3']\n",
    "outputs = ['bert/summary', 'logits']\n",
    "\n",
    "transformed_graph_def = TransformGraph(\n",
    "    input_graph_def, inputs, outputs, transforms\n",
    ")\n",
    "\n",
    "# The quantized graph is written next to the frozen one with a '.quantized' suffix.\n",
    "with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
    "    f.write(transformed_graph_def.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
